package com.rapidminer.test.utils;
import java.util.Iterator;
import java.util.List;
import org.junit.ComparisonFailure;
import junit.framework.Assert;
import junit.framework.AssertionFailedError;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeRole;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.table.NominalMapping;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.IOObjectCollection;
import com.rapidminer.operator.performance.PerformanceCriterion;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.operator.visualization.dependencies.NumericalMatrix;
import com.rapidminer.tools.math.Averagable;
import com.rapidminer.tools.math.AverageVector;
/**
* Extension for JUnit's Assert for testing RapidMiner objects.
*
* @author Simon Fischer, Marcin Skirzynski
*
*/
public class RapidAssert extends Assert {

    /** Absolute tolerance used for the delta-based double comparisons in this class. */
    public static final double DELTA = 0.000000001;

    /**
     * Compares two doubles with a tolerance of {@link #DELTA}, treating two NaN values as equal.
     *
     * @param message message to display if an error occurs
     * @param expected expected value
     * @param actual actual value
     */
    public static void assertEqualsNaN(String message, double expected, double actual) {
        if (Double.isNaN(expected)) {
            if (!Double.isNaN(actual)) {
                throw new AssertionFailedError(message + " expected: <" + expected + "> but was: <" + actual + ">");
            }
        } else {
            assertEquals(message, expected, actual, DELTA);
        }
    }

    /**
     * Tests if the special names of the attribute roles are equal and the associated attributes themselves.
     *
     * @param message message to display if an error occurs
     * @param expected expected value
     * @param actual actual value
     */
    public static void assertEquals(String message, AttributeRole expected, AttributeRole actual) {
        Assert.assertEquals(message + " (attribute role)", expected.getSpecialName(), actual.getSpecialName());
        Attribute a1 = expected.getAttribute();
        Attribute a2 = actual.getAttribute();
        assertEquals(message, a1, a2);
    }

    /**
     * Tests two attributes by comparing the name, value type, block type, default value and, if the
     * expected attribute is nominal, the nominal mapping.
     *
     * @param message message to display if an error occurs
     * @param expected expected value
     * @param actual actual value
     */
    public static void assertEquals(String message, Attribute expected, Attribute actual) {
        Assert.assertEquals(message + " (attribute name)", expected.getName(), actual.getName());
        Assert.assertEquals(message + " (attribute type)", expected.getValueType(), actual.getValueType());
        Assert.assertEquals(message + " (attribute block type)", expected.getBlockType(), actual.getBlockType());
        Assert.assertEquals(message + " (default value)", expected.getDefault(), actual.getDefault());
        if (expected.isNominal()) {
            assertEquals(message + " (nominal mapping)", expected.getMapping(), actual.getMapping());
        }
    }

    /**
     * Tests two nominal mappings for their size, their value lists and the index assigned to each value.
     *
     * @param message message to display if an error occurs
     * @param expected expected value
     * @param actual actual value
     */
    public static void assertEquals(String message, NominalMapping expected, NominalMapping actual) {
        Assert.assertEquals(message + " (nominal mapping size)", expected.size(), actual.size());
        List<String> v1 = expected.getValues();
        List<String> v2 = actual.getValues();
        Assert.assertEquals(message + " (nominal values)", v1, v2);
        if (v1 != null) { // the assertion above guarantees that v2 is non-null as well
            for (String value : v1) {
                Assert.assertEquals(message + " (index of nominal value '" + value + "')",
                        expected.getIndex(value), actual.getIndex(value));
            }
        }
    }

    /**
     * Tests two example sets by comparing their attributes and their sizes, and then iterating over
     * the examples until the number of rows to consider is reached. If this number is -1 there will
     * be no limitation.
     *
     * @param message message to display if an error occurs
     * @param es1 expected example set
     * @param es2 actual example set
     * @param numberOfRowsToConsider number of examples to consider for the test. -1 means: No limit!
     */
    public static void assertEquals(String message, ExampleSet es1, ExampleSet es2, int numberOfRowsToConsider) {
        if (numberOfRowsToConsider == -1) {
            numberOfRowsToConsider = Integer.MAX_VALUE;
        }

        assertEquals(message, es1.getAttributes(), es2.getAttributes());
        Assert.assertEquals(message + " (number of examples)", es1.size(), es2.size());

        Iterator<Example> i1 = es1.iterator();
        Iterator<Example> i2 = es2.iterator();
        int row = 0;
        while (i1.hasNext() && i2.hasNext() && (row < numberOfRowsToConsider)) {
            // fresh attribute iterators are required for every row because the callee consumes them
            assertEquals(message, i1.next(), i2.next(),
                    es1.getAttributes().allAttributes(), es2.getAttributes().allAttributes(), row);
            row++;
        }
    }

    /**
     * Tests two numerical matrices for equality. This covers the number of columns and rows, the
     * column and row names, the symmetry flag and every value within the matrix. Values are compared
     * via {@link #assertEqualsNaN(String, double, double)}, i.e. with a tolerance of {@link #DELTA}.
     *
     * @param message message to display if an error occurs
     * @param expected expected matrix
     * @param actual actual matrix
     */
    public static void assertEquals(String message, NumericalMatrix expected, NumericalMatrix actual) {
        int expNrOfCols = expected.getNumberOfColumns();
        int actNrOfCols = actual.getNumberOfColumns();
        assertEquals(message + " (column number is not equal)", expNrOfCols, actNrOfCols);

        int expNrOfRows = expected.getNumberOfRows();
        int actNrOfRows = actual.getNumberOfRows();
        assertEquals(message + " (row number is not equal)", expNrOfRows, actNrOfRows);

        int cols = expNrOfCols;
        int rows = expNrOfRows;

        for (int col = 0; col < cols; col++) {
            String expectedColName = expected.getColumnName(col);
            String actualColName = actual.getColumnName(col);
            assertEquals(message + " (column name at index " + col + " is not equal)", expectedColName, actualColName);
        }

        for (int row = 0; row < rows; row++) {
            String expectedRowName = expected.getRowName(row);
            String actualRowName = actual.getRowName(row);
            assertEquals(message + " (row name at index " + row + " is not equal)", expectedRowName, actualRowName);
        }

        assertEquals(message + " (matrix symmetry is not equal)", expected.isSymmetrical(), actual.isSymmetrical());

        for (int row = 0; row < rows; row++) {
            for (int col = 0; col < cols; col++) {
                double expectedVal = expected.getValue(row, col);
                double actualVal = actual.getValue(row, col);
                // The inherited three-argument assertEquals has no double overload; the doubles would
                // be boxed and compared exactly. Use the tolerant comparison like the example check.
                assertEqualsNaN(message + " (value at row " + row + " and column " + col + " is not equal)",
                        expectedVal, actualVal);
            }
        }
    }

    /**
     * Tests the two average vectors for equality by testing the size and each averagable.
     *
     * @param message message to display if an error occurs
     * @param expected expected vector
     * @param actual actual vector
     */
    public static void assertEquals(String message, AverageVector expected, AverageVector actual) {
        int expSize = expected.getSize();
        int actSize = actual.getSize();
        assertEquals(message + " (size of the average vector is not equal)", expSize, actSize);

        int size = expSize;
        for (int i = 0; i < size; i++) {
            RapidAssert.assertEquals(message, expected.getAveragable(i), actual.getAveragable(i));
        }
    }

    /**
     * Tests the two performance vectors for equality by testing the size, the criteria names, the
     * main criterion and each criterion.
     *
     * @param message message to display if an error occurs
     * @param expected expected vector
     * @param actual actual vector
     */
    public static void assertEquals(String message, PerformanceVector expected, PerformanceVector actual) {
        int expSize = expected.getSize();
        int actSize = actual.getSize();
        assertEquals(message + " (size of the performance vector is not equal)", expSize, actSize);

        int size = expSize;
        RapidAssert.assertArrayEquals(message, expected.getCriteriaNames(), actual.getCriteriaNames());
        RapidAssert.assertEquals(message, expected.getMainCriterion(), actual.getMainCriterion());

        for (int i = 0; i < size; i++) {
            RapidAssert.assertEquals(message, expected.getCriterion(i), actual.getCriterion(i));
        }
    }

    /**
     * Tests for equality by testing all averages, standard deviations and variances as well as the
     * average count. The averages, deviations and variances are compared via
     * {@link #assertEqualsNaN(String, double, double)}, i.e. with a tolerance of {@link #DELTA};
     * the average count goes through the inherited {@code assertEquals}.
     *
     * @param message message to display if an error occurs
     * @param expected expected averagable
     * @param actual actual averagable
     */
    public static void assertEquals(String message, Averagable expected, Averagable actual) {
        assertEqualsNaN(message + " (average is not equal)", expected.getAverage(), actual.getAverage());
        assertEqualsNaN(message + " (makro average is not equal)", expected.getMakroAverage(), actual.getMakroAverage());
        assertEqualsNaN(message + " (mikro average is not equal)", expected.getMikroAverage(), actual.getMikroAverage());
        assertEquals(message + " (average count is not equal)", expected.getAverageCount(), actual.getAverageCount());
        assertEqualsNaN(message + " (makro standard deviation is not equal)", expected.getMakroStandardDeviation(), actual.getMakroStandardDeviation());
        assertEqualsNaN(message + " (mikro standard deviation is not equal)", expected.getMikroStandardDeviation(), actual.getMikroStandardDeviation());
        assertEqualsNaN(message + " (standard deviation is not equal)", expected.getStandardDeviation(), actual.getStandardDeviation());
        assertEqualsNaN(message + " (makro variance is not equal)", expected.getMakroVariance(), actual.getMakroVariance());
        assertEqualsNaN(message + " (mikro variance is not equal)", expected.getMikroVariance(), actual.getMikroVariance());
        assertEqualsNaN(message + " (variance is not equal)", expected.getVariance(), actual.getVariance());
    }

    /**
     * Tests for equality by testing all averages, standard deviations and variances, as well as the
     * fitness, max fitness and example count. Fitness and max fitness are compared with a tolerance
     * of {@link #DELTA}; the example count goes through the inherited {@code assertEquals}.
     *
     * @param message message to display if an error occurs
     * @param expected expected criterion
     * @param actual actual criterion
     */
    public static void assertEquals(String message, PerformanceCriterion expected, PerformanceCriterion actual) {
        RapidAssert.assertEquals(message, (Averagable) expected, (Averagable) actual);
        assertEqualsNaN(message + " (fitness is not equal)", expected.getFitness(), actual.getFitness());
        assertEqualsNaN(message + " (max fitness is not equal)", expected.getMaxFitness(), actual.getMaxFitness());
        assertEquals(message + " (example count is not equal)", expected.getExampleCount(), actual.getExampleCount());
    }

    /**
     * Tests the two examples by testing the value of the examples for every given attribute.
     * This method is sensitive to the attribute ordering.
     *
     * @param message message to display if an error occurs
     * @param expected expected value
     * @param actual actual value
     * @param expectedAttributesToConsider an iterator for the attributes to consider for the expected example
     * @param actualAttributesToConsider an iterator for the attributes to consider for the actual example
     * @param row current row of the example set (zero-based; reported one-based in failure messages)
     */
    private static void assertEquals(String message, Example expected, Example actual,
            Iterator<Attribute> expectedAttributesToConsider, Iterator<Attribute> actualAttributesToConsider, int row) {
        while (expectedAttributesToConsider.hasNext() && actualAttributesToConsider.hasNext()) {
            Attribute a1 = expectedAttributesToConsider.next();
            Attribute a2 = actualAttributesToConsider.next();
            if (!a1.getName().equals(a2.getName())) {
                // this should have been detected by previous checks already
                throw new AssertionFailedError("Attribute ordering does not match: " + a1.getName() + "," + a2.getName());
            }
            if (a1.isNominal()) {
                Assert.assertEquals(message + " (example " + (row + 1) + ", nominal attribute value " + a1.getName() + ")",
                        expected.getNominalValue(a1), actual.getNominalValue(a2));
            } else {
                Assert.assertEquals(message + " (example " + (row + 1) + ", numerical attribute value " + a1.getName() + ")",
                        expected.getValue(a1), actual.getValue(a2), DELTA);
            }
        }
    }

    /**
     * Tests if all attributes are equal. This method is sensitive to the attribute ordering.
     *
     * @param message message to display if an error occurs
     * @param expected expected value
     * @param actual actual value
     */
    public static void assertEquals(String message, Attributes expected, Attributes actual) {
        Assert.assertEquals(message + " (number of attributes)", expected.allSize(), actual.allSize());
        Iterator<AttributeRole> i = expected.allAttributeRoles();
        // must iterate over the ACTUAL roles; iterating over the expected ones twice compares nothing
        Iterator<AttributeRole> j = actual.allAttributeRoles();
        while (i.hasNext() && j.hasNext()) {
            AttributeRole r1 = i.next();
            AttributeRole r2 = j.next();
            assertEquals(message, r1, r2);
        }
    }

    /**
     * Tests all objects in the array. Two {@code null} arrays are considered equal.
     *
     * @param message message to display if an error occurs
     * @param expected array with expected objects
     * @param actual array with actual objects
     */
    public static void assertArrayEquals(String message, Object[] expected, Object[] actual) {
        if (expected == null) {
            // the casts select the (String, Object, Object) overload so the caller's message is kept
            Assert.assertEquals(message, (Object) null, (Object) actual);
            return;
        }
        if (actual == null) {
            // Arrays.toString prints the elements; Object[].toString() would only print an identity hash
            throw new AssertionFailedError(message + " (expected " + java.util.Arrays.toString(expected) + " , but is null)");
        }
        Assert.assertEquals(message + " (array length is not equal)", expected.length, actual.length);
        for (int i = 0; i < expected.length; i++) {
            Assert.assertEquals(message, expected[i], actual[i]);
        }
    }

    /**
     * Tests all objects in the array, using an empty failure message.
     *
     * @param expected array with expected objects
     * @param actual array with actual objects
     */
    public static void assertArrayEquals(Object[] expected, Object[] actual) {
        assertArrayEquals("", expected, actual);
    }

    /**
     * Tests the collection of ioobjects by comparing the sizes and then the contained objects.
     *
     * @param message message to display if an error occurs
     * @param expected expected collection
     * @param actual actual collection
     */
    public static void assertEquals(String message, IOObjectCollection<IOObject> expected, IOObjectCollection<IOObject> actual) {
        assertEquals("Number of IOObjects in collections are not equal", expected.size(), actual.size());
        RapidAssert.assertEquals(message, expected.getObjects(), actual.getObjects());
    }

    /**
     * Tests if both lists of ioobjects are equal by comparing the sizes and then the objects pairwise.
     *
     * @param message message to display if an error occurs
     * @param expected expected value
     * @param actual actual value
     */
    public static void assertEquals(String message, List<IOObject> expected, List<IOObject> actual) {
        assertSize(expected, actual);

        Iterator<IOObject> expectedIter = expected.iterator();
        Iterator<IOObject> actualIter = actual.iterator();

        while (expectedIter.hasNext() && actualIter.hasNext()) {
            IOObject expectedIOO = expectedIter.next();
            IOObject actualIOO = actualIter.next();
            assertEquals(message, expectedIOO, actualIOO);
        }
    }

    /**
     * Tests if both lists of ioobjects are equal, using an empty failure message.
     *
     * @param expected expected value
     * @param actual actual value
     */
    public static void assertEquals(List<IOObject> expected, List<IOObject> actual) {
        RapidAssert.assertEquals("", expected, actual);
    }

    /**
     * Tests if both lists of IOObjects have the same size.
     *
     * @param expected list with the expected ioobjects
     * @param actual list with the actual ioobjects
     */
    public static void assertSize(List<IOObject> expected, List<IOObject> actual) {
        assertEquals("Number of connected output ports in the process is not equal with the number of ioobjects contained in the same folder with the format 'processname-expected-port-1', 'processname-expected-port-2', ...",
                expected.size(), actual.size());
    }

    /**
     * Tests if the two IOObjects are equal, using an empty failure message.
     *
     * @param expectedIOO expected ioobject
     * @param actualIOO actual ioobject
     */
    public static void assertEquals(IOObject expectedIOO, IOObject actualIOO) {
        RapidAssert.assertEquals("", expectedIOO, actualIOO);
    }

    /**
     * Tests if the two IOObjects are equal and prepends the given message to the failure text.
     * The comparison is dispatched on the runtime type; both objects must be of the same supported
     * kind (example set, numerical matrix, ioobject collection, performance vector or average
     * vector). Two {@code null} references are considered equal.
     *
     * @param message message to display if an error occurs
     * @param expectedIOO expected ioobject
     * @param actualIOO actual ioobject
     * @throws ComparisonFailure if the pair of classes is not supported by any of the comparisons
     */
    public static void assertEquals(String message, IOObject expectedIOO, IOObject actualIOO) {
        if (expectedIOO == null || actualIOO == null) {
            // passes for two nulls, fails with a regular assertion error (not an NPE) for a one-sided null
            Assert.assertEquals(message, (Object) expectedIOO, (Object) actualIOO);
            return;
        }

        if (expectedIOO instanceof ExampleSet && actualIOO instanceof ExampleSet) {
            RapidAssert.assertEquals(message + "ExampleSets are not equal", (ExampleSet) expectedIOO, (ExampleSet) actualIOO, -1);
        } else if (expectedIOO instanceof NumericalMatrix && actualIOO instanceof NumericalMatrix) {
            RapidAssert.assertEquals(message + "Numerical matrices are not equal", (NumericalMatrix) expectedIOO, (NumericalMatrix) actualIOO);
        } else if (expectedIOO instanceof IOObjectCollection && actualIOO instanceof IOObjectCollection) {
            // both sides are checked so that a type mismatch ends in the ComparisonFailure below
            // instead of a ClassCastException on the actual object
            @SuppressWarnings("unchecked")
            IOObjectCollection<IOObject> expectedCollection = (IOObjectCollection<IOObject>) expectedIOO;
            @SuppressWarnings("unchecked")
            IOObjectCollection<IOObject> actualCollection = (IOObjectCollection<IOObject>) actualIOO;
            RapidAssert.assertEquals(message + "Collection of IOObjects are not equal: ", expectedCollection, actualCollection);
        } else if (expectedIOO instanceof PerformanceVector && actualIOO instanceof PerformanceVector) {
            RapidAssert.assertEquals(message + "Performance vectors are not equal", (PerformanceVector) expectedIOO, (PerformanceVector) actualIOO);
        } else if (expectedIOO instanceof AverageVector && actualIOO instanceof AverageVector) {
            RapidAssert.assertEquals(message + "Average vectors are not equals", (AverageVector) expectedIOO, (AverageVector) actualIOO);
        } else {
            throw new ComparisonFailure("Comparison of the two given IOObject classes is not supported yet",
                    expectedIOO.toString(), actualIOO.toString());
        }
    }
}